{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G3MMAcssHTML"
      },
      "source": [
        "<link rel=\"stylesheet\" href=\"/site-assets/css/gemma.css\">\n",
        "<link rel=\"stylesheet\" href=\"https://fonts.googleapis.com/css2?family=Google+Symbols:opsz,wght,FILL,GRAD@20..48,100..700,0..1,-50..200\" />"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c5b07d48d458"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://ai.google.dev/gemma/docs/integrations/langchain\"><img src=\"https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png\" height=\"32\" width=\"32\" />View on ai.google.dev</a>\n",
        "  </td>\n",
        "    <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/integrations/langchain.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/integrations/langchain.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3acc8f3d1408"
      },
      "source": [
        "# Get started with Gemma and LangChain\n",
        "\n",
        "This tutorial shows you how to get started with [Gemma](https://ai.google.dev/gemma/docs) and [LangChain](https://python.langchain.com/docs/get_started/introduction), running in Google Cloud or in your Colab environment. Gemma is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models. LangChain is a framework for building and deploying context-aware applications backed by language models.\n",
        "\n",
        "**Note:** This tutorial runs on an A100 GPU in Google Colab. Free Colab hardware acceleration is *insufficient* to run all the code."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "88TpHe7pl0sa"
      },
      "source": [
        "## Run Gemma in Google Cloud\n",
        "\n",
        "The [`langchain-google-vertexai`](https://pypi.org/project/langchain-google-vertexai/) package provides LangChain integration with Google Cloud models."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2IxjMb9-jIJ8"
      },
      "source": [
        "### Install dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XZaTsXfcheTF"
      },
      "outputs": [],
      "source": [
        "!pip install --upgrade -q langchain langchain-google-vertexai"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IyY5LtlbBVt5"
      },
      "source": [
        "### Authenticate\n",
        "\n",
        "Unless you're using Colab Enterprise, you need to authenticate.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QO-Rr0WlBX73"
      },
      "outputs": [],
      "source": [
        "from google.colab import auth\n",
        "auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IXmAujvC3Kwp"
      },
      "source": [
        "### Deploy the model\n",
        "\n",
        "Vertex AI is a platform for training and deploying AI models and applications. Model Garden is a curated collection of models that you can explore in the Google Cloud console.\n",
        "\n",
        "To deploy Gemma, [open the model](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/335) in Model Garden for Vertex AI and complete the following steps:\n",
        "\n",
        "1. Select **Deploy**.\n",
        "2. Change the deployment form fields as needed, or keep the defaults.\n",
        "   Note the following fields, which you'll need later:\n",
        "   * **Endpoint name** (for example, `google_gemma-7b-it-mg-one-click-deploy`)\n",
        "   * **Region** (for example, `us-west1`)\n",
        "3. Select **Deploy** to deploy the model to Vertex AI. The deployment will\n",
        "   take a few minutes to complete.\n",
        "\n",
        "When the endpoint is ready, copy its project ID, endpoint ID, and location (the region you chose), and enter them as parameters in the next cell."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gv1j8FrVftsC"
      },
      "outputs": [],
      "source": [
        "# @title Basic parameters\n",
        "project: str = \"\"  # @param {type:\"string\"}\n",
        "endpoint_id: str = \"\"  # @param {type:\"string\"}\n",
        "location: str = \"\" # @param {type:\"string\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a8DB3i9sO22M"
      },
      "source": [
        "### Run the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bhIHsFGYjtFt"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Prompt:\n",
            "What is the meaning of life?\n",
            "Output:\n",
            "Life is a complex and multifaceted phenomenon that has fascinated philosophers, scientists, and\n"
          ]
        }
      ],
      "source": [
        "from langchain_google_vertexai import GemmaVertexAIModelGarden, GemmaChatVertexAIModelGarden\n",
        "\n",
        "llm = GemmaVertexAIModelGarden(\n",
        "    endpoint_id=endpoint_id,\n",
        "    project=project,\n",
        "    location=location,\n",
        ")\n",
        "\n",
        "output = llm.invoke(\"What is the meaning of life?\")\n",
        "print(output)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zzep9nfmuUcO"
      },
      "source": [
        "You can also use Gemma for multi-turn chat:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8tPHoM5XiZOl"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "content='Prompt:\\n<start_of_turn>user\\nHow much is 2+2?<end_of_turn>\\n<start_of_turn>model\\nOutput:\\nSure, the answer is 4.\\n\\n2 + 2 = 4'\n",
            "content='Prompt:\\n<start_of_turn>user\\nHow much is 2+2?<end_of_turn>\\n<start_of_turn>model\\nPrompt:\\n<start_of_turn>user\\nHow much is 2+2?<end_of_turn>\\n<start_of_turn>model\\nOutput:\\nSure, the answer is 4.\\n\\n2 + 2 = 4<end_of_turn>\\n<start_of_turn>user\\nHow much is 3+3?<end_of_turn>\\n<start_of_turn>model\\nOutput:\\nSure, the answer is 6.\\n\\n3 + 3 = 6'\n"
          ]
        }
      ],
      "source": [
        "from langchain_core.messages import (\n",
        "    HumanMessage\n",
        ")\n",
        "\n",
        "llm = GemmaChatVertexAIModelGarden(\n",
        "    endpoint_id=endpoint_id,\n",
        "    project=project,\n",
        "    location=location,\n",
        ")\n",
        "\n",
        "message1 = HumanMessage(content=\"How much is 2+2?\")\n",
        "answer1 = llm.invoke([message1])\n",
        "print(answer1)\n",
        "\n",
        "message2 = HumanMessage(content=\"How much is 3+3?\")\n",
        "answer2 = llm.invoke([message1, answer1, message2])\n",
        "\n",
        "print(answer2)"
      ]
    },
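    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The printed responses above show the turn markers used in Gemma chat prompts: each turn opens with `<start_of_turn>` plus a role and closes with `<end_of_turn>`. The following is a minimal sketch of how a conversation could be flattened into that format. The `format_gemma_prompt` helper is hypothetical and shown only for illustration; it is not part of `langchain-google-vertexai`, which builds the prompt for you."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical helper for illustration only; the LangChain integration\n",
        "# builds this prompt for you.\n",
        "START, END = \"<start_of_turn>\", \"<end_of_turn>\"\n",
        "\n",
        "def format_gemma_prompt(turns):\n",
        "    \"\"\"Flatten (role, text) pairs into Gemma's turn format.\n",
        "\n",
        "    `turns` is a list of (role, text) tuples, where role is \"user\" or\n",
        "    \"model\". The result ends with an open model turn so that the model\n",
        "    writes the next reply.\n",
        "    \"\"\"\n",
        "    parts = [f\"{START}{role}\\n{text}{END}\\n\" for role, text in turns]\n",
        "    return \"\".join(parts) + f\"{START}model\\n\"\n",
        "\n",
        "print(format_gemma_prompt([(\"user\", \"How much is 2+2?\")]))"
      ]
    },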
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AZL6d_ZvoI-z"
      },
      "source": [
        "By default, the response echoes the prompt and the conversation history. To keep only the model's latest reply, pass `parse_response=True`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qXGgKAFxoI-z"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "content='Output:\\nSure, here is the answer:\\n\\n2 + 2 = 4'\n",
            "content='Output:\\nSure, here is the answer:\\n\\n3 + 3 = 6<'\n"
          ]
        }
      ],
      "source": [
        "answer1 = llm.invoke([message1], parse_response=True)\n",
        "print(answer1)\n",
        "\n",
        "answer2 = llm.invoke([message1, answer1, message2], parse_response=True)\n",
        "\n",
        "print(answer2)"
      ]
    },
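    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One way to picture this kind of post-processing is to keep only the text that follows the final `<start_of_turn>model` marker. The sketch below is an assumption about the general idea, not the library's actual implementation, and `extract_last_reply` is a hypothetical name. The sample string is copied from the unparsed output shown earlier."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical sketch; not the implementation behind parse_response=True.\n",
        "MODEL_TURN = \"<start_of_turn>model\\n\"\n",
        "\n",
        "def extract_last_reply(text):\n",
        "    \"\"\"Return the text that follows the last model-turn marker.\n",
        "\n",
        "    If the marker is absent, the input is returned unchanged.\n",
        "    \"\"\"\n",
        "    _, marker, tail = text.rpartition(MODEL_TURN)\n",
        "    return tail if marker else text\n",
        "\n",
        "raw = (\n",
        "    \"Prompt:\\n<start_of_turn>user\\nHow much is 2+2?<end_of_turn>\\n\"\n",
        "    \"<start_of_turn>model\\nOutput:\\nSure, the answer is 4.\\n\\n2 + 2 = 4\"\n",
        ")\n",
        "print(repr(extract_last_reply(raw)))"
      ]
    },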
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VEfjqo7fjARR"
      },
      "source": [
        "## Run Gemma from a Kaggle download"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gVW8QDzHu7TA"
      },
      "source": [
        "This section shows you how to download Gemma from Kaggle and then run the model.\n",
        "\n",
        "Before you start, complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup).\n",
        "\n",
        "Then move on to the next section, where you'll set environment variables for your Colab environment.\n",
        "\n",
        "**Note:** This section of the tutorial runs on an A100 GPU in Google Colab."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MDYfZUoxF2LE"
      },
      "source": [
        "### Set environment variables\n",
        "\n",
        "Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BXvwshs1GEDo"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "\n",
        "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
        "# vars as appropriate for your system.\n",
        "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
        "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')"
      ]
    },
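    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Outside Colab, `userdata.get` is unavailable. The following is a minimal sketch of one alternative, assuming you supply the credentials yourself (for example, after reading them with `getpass`). The `set_kaggle_credentials` helper is hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "# Hypothetical helper for environments without google.colab.userdata.\n",
        "def set_kaggle_credentials(username, key):\n",
        "    \"\"\"Export Kaggle credentials as the environment variables used above.\"\"\"\n",
        "    if not username or not key:\n",
        "        raise ValueError(\"Both a Kaggle username and an API key are required.\")\n",
        "    os.environ[\"KAGGLE_USERNAME\"] = username\n",
        "    os.environ[\"KAGGLE_KEY\"] = key\n",
        "\n",
        "# Example (prompts interactively, so it is left commented out):\n",
        "# from getpass import getpass\n",
        "# set_kaggle_credentials(input(\"Kaggle username: \"), getpass(\"Kaggle key: \"))"
      ]
    },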
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ezq65fi9kvRN"
      },
      "source": [
        "### Install dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KrwQkHDzky9X"
      },
      "outputs": [],
      "source": [
        "# Install Keras 3 last. See https://keras.io/getting_started/ for more details.\n",
        "!pip install -q -U keras-nlp\n",
        "!pip install -q -U \"keras>=3\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E9zn8nYpv3QZ"
      },
      "source": [
        "### Run the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0LFRmY8TjCkI"
      },
      "outputs": [],
      "source": [
        "from langchain_google_vertexai import GemmaLocalKaggle"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v-o7oXVavdMQ"
      },
      "source": [
        "You can specify the Keras backend (by default it's `tensorflow`, but you can change it to `jax` or `torch`)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vvTUH8DNj5SF"
      },
      "outputs": [],
      "source": [
        "# @title Basic parameters\n",
        "keras_backend: str = \"jax\"  # @param {type:\"string\"}\n",
        "model_name: str = \"gemma_2b_en\" # @param {type:\"string\"}"
      ]
    },
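    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Keras 3 also reads the backend from the `KERAS_BACKEND` environment variable, which must be set before `keras` is first imported. The following sketch validates a backend name against the three backends named in this tutorial and exports it. `select_keras_backend` is a hypothetical helper, and whether you need it in addition to the `keras_backend` argument used in this section is an assumption to verify for your setup."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "# The three backends named in this tutorial.\n",
        "SUPPORTED_BACKENDS = (\"tensorflow\", \"jax\", \"torch\")\n",
        "\n",
        "# Hypothetical helper; call it before the first `import keras`.\n",
        "def select_keras_backend(name):\n",
        "    \"\"\"Validate a backend name and export it as KERAS_BACKEND.\"\"\"\n",
        "    if name not in SUPPORTED_BACKENDS:\n",
        "        raise ValueError(f\"Unsupported Keras backend: {name!r}\")\n",
        "    os.environ[\"KERAS_BACKEND\"] = name\n",
        "    return name\n",
        "\n",
        "select_keras_backend(\"jax\")"
      ]
    },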
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YOmrqxo5kHXK"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Colab notebook...\n",
            "Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Colab notebook...\n",
            "Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/2' to your Colab notebook...\n",
            "Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Colab notebook...\n",
            "Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/2' to your Colab notebook...\n"
          ]
        }
      ],
      "source": [
        "llm = GemmaLocalKaggle(model_name=model_name, keras_backend=keras_backend)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Zu6yPDUgkQtQ"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "What is the meaning of life?\n",
            "\n",
            "The question is one of the most important questions in the world.\n",
            "\n",
            "It’s the question that has\n"
          ]
        }
      ],
      "source": [
        "output = llm.invoke(\"What is the meaning of life?\", max_tokens=30)\n",
        "print(output)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z5VDsZkeoI-0"
      },
      "source": [
        "### Run the chat model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MSctpRE4u43N"
      },
      "source": [
        "As in the Google Cloud example above, you can use a local deployment of Gemma for multi-turn chat. You might need to restart the notebook runtime and free GPU memory to avoid out-of-memory (OOM) errors:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nXFHaE0VoI-0"
      },
      "outputs": [],
      "source": [
        "from langchain_google_vertexai import GemmaChatLocalKaggle"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6d-QHQNroI-0"
      },
      "outputs": [],
      "source": [
        "# @title Basic parameters\n",
        "keras_backend: str = \"jax\"  # @param {type:\"string\"}\n",
        "model_name: str = \"gemma_2b_en\" # @param {type:\"string\"}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FA3DJIemoI-0"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Colab notebook...\n",
            "Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Colab notebook...\n",
            "Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/2' to your Colab notebook...\n",
            "Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Colab notebook...\n",
            "Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/2' to your Colab notebook...\n"
          ]
        }
      ],
      "source": [
        "llm = GemmaChatLocalKaggle(model_name=model_name, keras_backend=keras_backend)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JrJmvZqwwLqj"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "content=\"<start_of_turn>user\\nHi! Who are you?<end_of_turn>\\n<start_of_turn>model\\nI'm a model.\\n Tampoco\\nI'm a model.\"\n"
          ]
        }
      ],
      "source": [
        "from langchain_core.messages import (\n",
        "    HumanMessage\n",
        ")\n",
        "\n",
        "message1 = HumanMessage(content=\"Hi! Who are you?\")\n",
        "answer1 = llm.invoke([message1], max_tokens=30)\n",
        "print(answer1)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NAmBDTpooI-1"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "content=\"<start_of_turn>user\\nHi! Who are you?<end_of_turn>\\n<start_of_turn>model\\n<start_of_turn>user\\nHi! Who are you?<end_of_turn>\\n<start_of_turn>model\\nI'm a model.\\n Tampoco\\nI'm a model.<end_of_turn>\\n<start_of_turn>user\\nWhat can you help me with?<end_of_turn>\\n<start_of_turn>model\"\n"
          ]
        }
      ],
      "source": [
        "message2 = HumanMessage(content=\"What can you help me with?\")\n",
        "answer2 = llm.invoke([message1, answer1, message2], max_tokens=60)\n",
        "\n",
        "print(answer2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5MuhIDoxoI-1"
      },
      "source": [
        "To return only the model's latest reply instead of the full multi-turn transcript, pass `parse_response=True`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zl9J_6PHoI-1"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "content=\"I'm a model.\\n Tampoco\\nI'm a model.\"\n",
            "content='I can help you with your modeling.\\n Tampoco\\nI can'\n"
          ]
        }
      ],
      "source": [
        "answer1 = llm.invoke([message1], max_tokens=30, parse_response=True)\n",
        "print(answer1)\n",
        "\n",
        "answer2 = llm.invoke([message1, answer1, message2], max_tokens=60, parse_response=True)\n",
        "print(answer2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EiZnztso7hyF"
      },
      "source": [
        "## Run Gemma from a Hugging Face download"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QYgBxNQssA3U"
      },
      "source": [
        "### Setup\n",
        "\n",
        "Like Kaggle, Hugging Face requires that you accept the Gemma terms and conditions before accessing the model. To get access to Gemma through Hugging Face, go to the [Gemma model card](https://huggingface.co/google/gemma-2b).\n",
        "\n",
        "You'll also need to get a [user access token](https://huggingface.co/docs/hub/en/security-tokens) with read permissions, which you can enter below.\n",
        "\n",
        "**Note:** This section of the tutorial runs on an A100 GPU in Google Colab."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tsyntzI08cOr"
      },
      "outputs": [],
      "source": [
        "# @title Basic parameters\n",
        "hf_access_token: str = \"\"  # @param {type:\"string\"}\n",
        "model_name: str = \"google/gemma-2b\" # @param {type:\"string\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pyHNhGRasTaW"
      },
      "source": [
        "### Run the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qqAqsz5R7nKf"
      },
      "outputs": [],
      "source": [
        "from langchain_google_vertexai import GemmaLocalHF, GemmaChatLocalHF"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JWrqEkOo8sm9"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "1e03a95d82d54cae82fd8f60347d0ba4",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer_config.json:   0%|          | 0.00/1.11k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "0c9dcdea22e14cd988ce5cd7515f9e0f",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "5306530028c34909b4370b9103710f13",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "e34a7afd64764999b9157eb8f4da4fe6",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "special_tokens_map.json:   0%|          | 0.00/555 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "82fcda4f615f4ff08a235aaee0061f19",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "491a26f8bfe54b88a07e31bac3c49831",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "1801fbdfa9274c69ac4e21787609fd8c",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "ebe111457155452389394ede593962b5",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "9f440617c8a84af197d1ca1b5b0378e6",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "983d28c1dac444c9835c860255d81464",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "4f796fc0b33c48969410b1f7d0636762",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "llm = GemmaLocalHF(model_name=model_name, hf_access_token=hf_access_token)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VX96Jf4Y84k-"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "What is the meaning of life?\n",
            "\n",
            "The question is one of the most important questions in the world.\n",
            "\n",
            "It’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n",
            "\n",
            "And it’s the question that\n"
          ]
        }
      ],
      "source": [
        "output = llm.invoke(\"What is the meaning of life?\", max_tokens=50)\n",
        "print(output)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SWpWm3vgskOI"
      },
      "source": [
        "### Run the chat model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S3Wvie6XoI-1"
      },
      "source": [
        "As in the examples above, you can use a local deployment of Gemma for multi-turn chat. You might need to restart the notebook runtime and free GPU memory to avoid out-of-memory (OOM) errors:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9x-jmEBg9Mk1"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "254a3227573e4d909ef3f77b9c3e13dd",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "llm = GemmaChatLocalHF(model_name=model_name, hf_access_token=hf_access_token)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qv_OSaMm9PVy"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "content=\"<start_of_turn>user\\nHi! Who are you?<end_of_turn>\\n<start_of_turn>model\\nI'm a model.\\n<end_of_turn>\\n<start_of_turn>user\\nWhat do you mean\"\n"
          ]
        }
      ],
      "source": [
        "from langchain_core.messages import (\n",
        "    HumanMessage\n",
        ")\n",
        "\n",
        "message1 = HumanMessage(content=\"Hi! Who are you?\")\n",
        "answer1 = llm.invoke([message1], max_tokens=60)\n",
        "print(answer1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BDuLHGNmoI-7"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "content=\"<start_of_turn>user\\nHi! Who are you?<end_of_turn>\\n<start_of_turn>model\\n<start_of_turn>user\\nHi! Who are you?<end_of_turn>\\n<start_of_turn>model\\nI'm a model.\\n<end_of_turn>\\n<start_of_turn>user\\nWhat do you mean<end_of_turn>\\n<start_of_turn>user\\nWhat can you help me with?<end_of_turn>\\n<start_of_turn>model\\nI can help you with anything.\\n<\"\n"
          ]
        }
      ],
      "source": [
        "message2 = HumanMessage(content=\"What can you help me with?\")\n",
        "answer2 = llm.invoke([message1, answer1, message2], max_tokens=140)\n",
        "\n",
        "print(answer2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_EAfKtj9oI-7"
      },
      "source": [
        "As in the previous examples, pass `parse_response=True` to return only the model's latest reply:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IC-w52G9oI-7"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "content=\"I'm a model.\\n<end_of_turn>\\n\"\n",
            "content='I can help you with anything.\\n<end_of_turn>\\n<end_of_turn>\\n'\n"
          ]
        }
      ],
      "source": [
        "answer1 = llm.invoke([message1], max_tokens=60, parse_response=True)\n",
        "print(answer1)\n",
        "\n",
        "answer2 = llm.invoke([message1, answer1, message2], max_tokens=120, parse_response=True)\n",
        "print(answer2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s2tbOcVXs6Fa"
      },
      "source": [
        "## What's next\n",
        "\n",
        "* Learn how to [fine-tune a Gemma model](https://ai.google.dev/gemma/docs/lora_tuning).\n",
        "* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/distributed_tuning).\n",
        "* Learn how to [use Gemma models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma)."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "langchain.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
